import numpy as np
import time
from master_problem import MasterProblem
from dca_solver import DCASolver
from separation_problem import SeparationProblem
from utils import initialize_K
from local_search import local_search
from pricing_IP import pricing_IP

'''
    Implements the column generation algorithm for Bayesian network structure learning
    Includes initialization, CG phase, row generation phase, and IP phase
    '''

class ColumnGeneration:
    '''
    Row and column generation solver for Bayesian network structure learning (BNSL).

    Each outer iteration of column_generation() runs three phases:
      1. CG phase: solve the restricted master LP (RMLP) and the per-node pricing
         problems, adding candidate parent sets until every node's latest pricing
         objective is >= -pricing_err.
      2. Row generation phase: solve a separation problem on the RMLP solution and
         update the cluster list until no new cluster is returned.
      3. IP phase: solve the restricted master IP (RMIP) and decode a graph from it.
    The loop ends at the time limit, or when the convergence counter (incremented
    when the RMIP objective changes by less than cg_err or when no new pattern is
    found, reset when the RMIP objective changes) reaches count_lim.
    '''

    # Pricing strategies accepted by column_generation().
    _METHODS = ('DCA', 'DCA-HC', 'MINLP')

    def __init__(self, data, regu_Lambda, method, time_limit, save_path):
        '''
        data: Data instance (must provide n, ndata and data_type)
        regu_Lambda: the regularization hyperparameter
        method: 'DCA', 'DCA-HC' or 'MINLP' (pricing strategy)
        time_limit: time limit in seconds for the algorithm on each data instance;
            3 hours in our experiments
        save_path: base output path, defined in run_bayesian_network.py;
            '.txt' and '_GurobiLog.txt' are appended to it
        -------------------------------------------------------------
        K: a dictionary, K[i]: k*n numpy array, each row is a candidate parent set pattern for node i
        '''
        self.data = data
        self.data_type = data.data_type
        self.regu_Lambda = regu_Lambda
        self.method = method
        self.time_limit = time_limit
        self.K = None  # the set of candidate parent sets; built in column_generation
        self.C_set = []  # clusters list
        self.save_path_1 = save_path + '.txt'  # for saving algorithm outputs
        self.save_path_2 = save_path + '_GurobiLog.txt'  # for saving Gurobi outputs
        self.evaluation_time = 0  # total evaluation time in Lovasz extension functions
        self.separation_time = 0  # time spent in the row generation phase
        self.pricing_time = 0  # time spent in the CG (pricing) phase
        self.findcycle_time = 0  # time spent in the IP phase
        self.converge_count = 0  # to determine whether to exit the CG loop
        self.IP_sol_X = None  # latest RMIP solution, used to warm-start the next RMIP
        self.graph = None  # latest decoded graph; row i is the pattern chosen for node i
        self.BIC_score = None  # set at the end of column_generation
        self.LP_obj_history = []
        self.IP_obj_history = []

    def column_generation(self):
        '''
        Run the row and column generation framework for BNSL.

        Returns:
            graph: n x n numpy array; row i is the parent-set pattern selected for
                node i in the last solved RMIP.
            f: the file object returned by print_result(), still open; the caller
                is responsible for closing it.

        Raises:
            ValueError: if self.method is not 'DCA', 'DCA-HC' or 'MINLP'.
            RuntimeError: if the time limit is reached before any RMIP is solved,
                so there is no graph to return.
        '''
        if self.method not in self._METHODS:
            raise ValueError(f"method must be one of {self._METHODS}, got {self.method!r}")

        d = self.data
        n = d.n
        ndata = d.ndata
        regu_Lambda = self.regu_Lambda

        # Per-run state: reset so that repeated calls do not inherit a stale
        # convergence counter or a warm start sized for a different K.
        self.IP_obj_history = []
        self.LP_obj_history = []
        self.K = initialize_K(d)  # initialize the candidate parent sets
        self.C_set = []  # clusters list
        self.converge_count = 0
        self.IP_sol_X = None
        pricing_err = 1e-3  # for determining negative reduced cost
        cg_err = 1e-3  # for determining the convergence of the RMIP objective
        # Stop once converge_count reaches count_lim, e.g. after 3 RMIP solves whose
        # objectives differ by less than cg_err (2 successive comparisons).
        count_lim = 2
        time_start = time.time()  # use for the time limit

        def within_time_limit():
            return (time.time() - time_start) < self.time_limit

        # clean the log file
        with open(self.save_path_2, 'w'):
            pass

        while within_time_limit():
            rc_negative = [True] * n  # nodes whose last pricing found a negative reduced cost
            pricing_history = {}  # latest (rounded) pricing objective per node
            master_model = None

            # ---------------- CG phase ----------------
            pricing_part_time_start = time.time()
            new_pattern_found = False  # whether any new pattern was found in this phase
            while within_time_limit():
                # Solve RMLP and obtain optimal dual solutions
                mp = MasterProblem(d, self.data_type, self.K, self.C_set, True, regu_Lambda, self.save_path_2)
                master_model, X = mp.solve()
                sol_X = [np.array([var.X for var in var_list]) for var_list in X]
                duals = master_model.getAttr('Pi', master_model.getConstrs())

                # Solve pricing problems for each node that still had a negative reduced cost
                for i in range(n):
                    if not rc_negative[i]:
                        continue
                    print('Solve pricing for node ', i)
                    # warm-start DCA from the pattern with the largest RMLP value for node i
                    pricing_obj = self._solve_pricing(i, np.argmax(sol_X[i]), duals)
                    rc_negative[i] = pricing_obj < -pricing_err
                    if rc_negative[i]:
                        new_pattern_found = True  # find at least one new pattern
                    pricing_history[i] = round(pricing_obj, 2)

                if not any(rc_negative):  # all reduced costs nonnegative
                    if self.method == 'DCA-HC':
                        # add a local search (not included in our results in paper)
                        self._hill_climb(duals)
                    break

                print('Pricing objective\n', pricing_history.values())
                print('------------------------------------')
            self.pricing_time += time.time() - pricing_part_time_start
            print('Pricing objective\n', pricing_history.values())

            # ---------------- Row generation phase ----------------
            separation_part_start_time = time.time()
            while within_time_limit() and new_pattern_found:
                # Solve RMLP and obtain optimal primal solutions
                mp = MasterProblem(d, self.data_type, self.K, self.C_set, True, regu_Lambda, self.save_path_2)
                master_model, X = mp.solve()
                sol_X = [np.array([var.X for var in var_list]) for var_list in X]

                # Solve separation IP
                sp = SeparationProblem(sol_X, self.K, self.C_set)
                have_new_cluster, self.C_set = sp.solve()
                if not have_new_cluster:  # no violation of the constraints
                    break
            self.separation_time += time.time() - separation_part_start_time

            if master_model is None:  # time limit hit before any RMLP was solved
                break
            self.LP_obj_history.append(round(master_model.objVal, 2))  # lastly solved RMLP

            # ---------------- IP phase ----------------
            # Solve the RMIP whenever new columns were added, and at least once so
            # that a graph exists even if the initial patterns already price out.
            if within_time_limit() and (new_pattern_found or self.IP_sol_X is None):
                IP_part_start_time = time.time()
                # Solve RMIP with callbacks to obtain a DAG
                mp = MasterProblem(d, self.data_type, self.K, self.C_set, False, regu_Lambda, self.save_path_2)
                mp.set_initial_solution(self.IP_sol_X)
                master_model, self.C_set, X = mp.solve()
                sol_X = [np.array([var.X for var in var_list]) for var_list in X]
                self.IP_sol_X = sol_X
                self.graph = self._decode_graph(sol_X)
                self.findcycle_time += time.time() - IP_part_start_time

                IP_obj = master_model.ObjVal  # RMIP objective value
                if self.IP_obj_history:  # compare with the previous RMIP objective
                    if abs(self.IP_obj_history[-1] - IP_obj) < cg_err:
                        self.converge_count += 1
                    else:
                        self.converge_count = 0
                # IP_obj = self.prune_graph() # prune graphs, not included in our paper
                self.IP_obj_history.append(IP_obj)

            if not new_pattern_found:
                self.converge_count += 1

            if self.converge_count >= count_lim:  # treat as convergence for CG
                break

            # continue iterations in CG
            self._print_progress()
            print('-------------------------------------------------')

        if not self.IP_obj_history:
            raise RuntimeError('Time limit reached before any RMIP was solved; no graph is available.')

        self.BIC_score = -self.IP_obj_history[-1]  # calculate BIC score from objective
        if self.data_type == 'C':
            self.BIC_score -= n * np.log(ndata)
        f = self.print_result()
        return self.graph, f

    def _solve_pricing(self, i, init_choice, duals):
        '''
        Solve the pricing problem for node i with the configured method, add the
        new pattern to K[i], and drop duplicate rows of K[i].

        i: node index
        init_choice: row index into K[i] used to warm-start DCA (unused by MINLP)
        duals: RMLP constraint duals ('Pi')
        Returns the pricing objective; the caller treats values below
        -pricing_err as a negative reduced cost.
        '''
        d = self.data
        if self.method in ('DCA', 'DCA-HC'):
            init_pattern = self.K[i][init_choice, :]  # for warm-start initialization in DCA
            # dcasolver = DCASolver(d, self.data_type, self.K, i, init_pattern, duals, self.C_set, 1e5, 1e5, 1e-2, 1e-2, regu_Lambda)
            dcasolver = DCASolver(d, self.data_type, self.K, i, init_pattern, duals, self.C_set,
                                  1e5, 1e5, 1e-6, 1e-4, self.regu_Lambda)
            # DCASolver returns the updated candidate sets
            self.K, pattern, pricing_obj, eval_time = dcasolver.solve()
        else:  # 'MINLP'
            MIP_solver = pricing_IP(d, i, duals, self.C_set, self.regu_Lambda)
            pattern, pricing_obj, eval_time = MIP_solver.solve_pricing()
            self.K[i] = np.vstack((self.K[i], pattern))

        self.evaluation_time += eval_time  # total evaluation time in Lovasz extension functions

        # Remove repeated patterns while keeping first occurrences in their original
        # order, so existing column indices stay valid (needed for the IP warm start).
        unique_indices = np.unique(self.K[i], axis=0, return_index=True)[1]
        self.K[i] = self.K[i][np.sort(unique_indices)]
        return pricing_obj

    def _hill_climb(self, duals):
        '''Run local_search for every node, starting from each node's last stored pattern.'''
        for i in range(self.data.n):
            print('Hill climbing for node ', i)
            init_pattern = self.K[i][-1, :]  # to start with the last pattern of node i
            self.K = local_search(self.data, i, self.K, init_pattern, duals, self.C_set, self.regu_Lambda)

    def _decode_graph(self, sol_X):
        '''Build the n x n graph whose row i is the pattern of node i with the largest RMIP value.'''
        n = self.data.n
        graph = np.zeros((n, n))
        for i in range(n):
            choice = np.argmax(sol_X[i])  # the chosen pattern index
            graph[i, :] = self.K[i][choice, :]  # the chosen pattern
        return graph

    def _print_progress(self):
        '''Print pattern counts, cluster count, current graph and objective histories.'''
        print('Number of Patterns for each node')
        print(*(len(self.K[i]) for i in range(self.data.n)))
        print('Number of clusters: ', len(self.C_set))
        print('Graph:\n', self.graph)
        print('LP objective history:\n', self.LP_obj_history)
        print('IP objective history:\n', self.IP_obj_history)

    def print_result(self):
        '''
        Print the final summary and write it to self.save_path_1.

        Returns the file object for self.save_path_1; it is NOT closed here, so
        the caller is responsible for closing it.
        '''
        n = self.data.n
        f = open(self.save_path_1, 'w')
        self._print_progress()
        print('BIC score: ', round(self.BIC_score, 2))
        print('Evaluation time in DCA: ', self.evaluation_time)

        pattern_counts = ''.join(str(len(self.K[i])) + ' ' for i in range(n))
        f.write('Number of Patterns for each node\n')
        f.write(pattern_counts + '\n')
        f.write('Number of clusters: ' + str(len(self.C_set)) + '\n')
        f.write('Graph:\n' + str(self.graph) + '\n')
        f.write('LP objective history:\n' + str(self.LP_obj_history) + '\n')
        f.write('IP objective history:\n' + str(self.IP_obj_history) + '\n')
        f.write('BIC score: ' + str(round(self.BIC_score, 2)) + '\n')
        f.write('Evaluation time in DCA: ' + str(round(self.evaluation_time, 2)) + '\n')
        f.write('Pricing time: ' + str(round(self.pricing_time, 2)) + '\n')
        f.write('Separation time: ' + str(round(self.separation_time, 2)) + '\n')
        f.write('Find cycle time: ' + str(round(self.findcycle_time, 2)) + '\n')
        return f

    # NOTE(review): the graph-pruning helpers below are disabled. They reference
    # names not imported in this module (continuous_scores, discrete_scores, math)
    # and self.prune_dict, which is never initialized; fix those before re-enabling.
    # def prune_graph(self):
    #     d = self.data
    #     n = d.n
    #     total_cost = 0
    #     for i in range(n):
    #         parent_pattern = self.graph[i,:]
    #         J = np.where(parent_pattern>0.5)[0].tolist()
    #         original_cost = self.cost_obj(d, J, i)
    #         self.prune_dict[str(i)+str(J)] = original_cost # to reduce repeat calculation
    #         admited_J, current_best_cost = self.prune(i, J, original_cost)
    #         total_cost += current_best_cost
    #         J = admited_J
    #         self.graph[i,:] = np.zeros(n)
    #         self.graph[i,J] = 1
    #     return total_cost

    # def prune(self, i, admited_J, current_best_cost): # locally search for optimal substructure
    #     d = self.data
    #     J = admited_J
    #     if len(J)==0:
    #         return admited_J, current_best_cost
    #     for j in J:
    #         pruned_J = J.copy()
    #         pruned_J.remove(j) # prune locally
    #         if str(i)+str(pruned_J) in self.prune_dict.keys():
    #             pruned_cost = self.prune_dict[str(i)+str(pruned_J)]
    #         else:
    #             pruned_cost = self.cost_obj(d, pruned_J, i)
    #             self.prune_dict[str(i)+str(pruned_J)] = pruned_cost
    #         if pruned_cost < current_best_cost: # have improvement
    #             # update by recursion
    #             admited_J, current_best_cost = self.prune(i, pruned_J, pruned_cost)
    #     return admited_J, current_best_cost

    # def cost_obj(self, d, J, i): # local cost
    #     n = d.n
    #     ndata = d.ndata
    #     data_type = d.data_type
    #     if data_type == 'C':
    #         cost = continuous_scores.cost
    #         local_cost_obj = np.log(cost(self.data, J, i)) * ndata / 2 + (1 + np.log(2 * np.pi)) * ndata / 2 + self.regu_Lambda * len(J)
    #     else:
    #         cost = discrete_scores.cost
    #         local_cost_obj = cost(self.data, J, i) * ndata + self.regu_Lambda*(self.data.arity[i]-1)*math.prod([self.data.arity[j] for j in J])
    #     return local_cost_obj